diff --git a/homeassistant/components/homeassistant/triggers/time.py b/homeassistant/components/homeassistant/triggers/time.py index adfe592319b..b636a7a3590 100644 --- a/homeassistant/components/homeassistant/triggers/time.py +++ b/homeassistant/components/homeassistant/triggers/time.py @@ -4,7 +4,14 @@ from functools import partial import voluptuous as vol -from homeassistant.const import CONF_AT, CONF_PLATFORM +from homeassistant.components import sensor +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + CONF_AT, + CONF_PLATFORM, + STATE_UNAVAILABLE, + STATE_UNKNOWN, +) from homeassistant.core import HassJob, callback from homeassistant.helpers import config_validation as cv from homeassistant.helpers.event import ( @@ -18,8 +25,8 @@ import homeassistant.util.dt as dt_util _TIME_TRIGGER_SCHEMA = vol.Any( cv.time, - vol.All(str, cv.entity_domain("input_datetime")), - msg="Expected HH:MM, HH:MM:SS or Entity ID from domain 'input_datetime'", + vol.All(str, cv.entity_domain(("input_datetime", "sensor"))), + msg="Expected HH:MM, HH:MM:SS or Entity ID with domain 'input_datetime' or 'sensor'", ) TRIGGER_SCHEMA = vol.Schema( @@ -60,14 +67,16 @@ async def async_attach_trigger(hass, config, action, automation_info): def update_entity_trigger(entity_id, new_state=None): """Update the entity trigger for the entity_id.""" # If a listener was already set up for entity, remove it. - remove = entities.get(entity_id) + remove = entities.pop(entity_id, None) if remove: remove() - removes.remove(remove) remove = None + if not new_state: + return + # Check state of entity. If valid, set up a listener. - if new_state: + if new_state.domain == "input_datetime": has_date = new_state.attributes["has_date"] if has_date: year = new_state.attributes["year"] @@ -111,16 +120,32 @@ async def async_attach_trigger(hass, config, action, automation_info): minute=minute, second=second, ) + elif ( + new_state.domain == "sensor" + and new_state.attributes.get(ATTR_DEVICE_CLASS) + == sensor.DEVICE_CLASS_TIMESTAMP + and new_state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ): + trigger_dt = dt_util.parse_datetime(new_state.state) + + if trigger_dt is not None and trigger_dt > dt_util.utcnow(): + remove = async_track_point_in_time( + hass, + partial( + time_automation_listener, + f"time set in {entity_id}", + entity_id=entity_id, + ), + trigger_dt, + ) # Was a listener set up? if remove: - removes.append(remove) - - entities[entity_id] = remove + entities[entity_id] = remove for at_time in config[CONF_AT]: if isinstance(at_time, str): - # input_datetime entity + # entity update_entity_trigger(at_time, new_state=hass.states.get(at_time)) else: # datetime.time @@ -144,6 +169,8 @@ async def async_attach_trigger(hass, config, action, automation_info): @callback def remove_track_time_changes(): """Remove tracked time changes.""" + for remove in entities.values(): + remove() for remove in removes: remove() diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index e78717ac609..5b2ad0da2ac 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -282,25 +282,36 @@ comp_entity_ids = vol.Any( ) -def entity_domain(domain: str) -> Callable[[Any], str]: +def entity_domain(domain: Union[str, List[str]]) -> Callable[[Any], str]: """Validate that entity belong to domain.""" + ent_domain = entities_domain(domain) def validate(value: Any) -> str: """Test if entity domain is domain.""" - ent_domain = entities_domain(domain) return ent_domain(value)[0] return validate -def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]: +def entities_domain( + domain: Union[str, List[str]] +) -> Callable[[Union[str, List]], List[str]]: """Validate that entities belong to domain.""" + if isinstance(domain, str): + + def check_invalid(val: str) -> bool: + return val != domain + + else: + + def check_invalid(val: str) -> bool: + return val not in domain def validate(values: Union[str, List]) -> List[str]: """Test if entity domain is domain.""" values = entity_ids(values) for ent_id in values: - if split_entity_id(ent_id)[0] != domain: + if check_invalid(split_entity_id(ent_id)[0]): raise vol.Invalid( f"Entity ID '{ent_id}' does not belong to domain '{domain}'" ) diff --git a/tests/components/homeassistant/triggers/test_time.py b/tests/components/homeassistant/triggers/test_time.py index 91fd57beed3..673d0231912 100644 --- a/tests/components/homeassistant/triggers/test_time.py +++ b/tests/components/homeassistant/triggers/test_time.py @@ -3,8 +3,8 @@ from datetime import timedelta import pytest -import homeassistant.components.automation as automation -from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_OFF +from homeassistant.components import automation, sensor +from homeassistant.const import ATTR_DEVICE_CLASS, ATTR_ENTITY_ID, SERVICE_TURN_OFF from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -391,3 +391,104 @@ async def test_untrack_time_change(hass): ) assert len(mock_track_time_change.mock_calls) == 3 + + +async def test_if_fires_using_at_sensor(hass, calls): + """Test for firing at sensor time.""" + now = dt_util.now() + + trigger_dt = now.replace(hour=5, minute=0, second=0, microsecond=0) + timedelta(2) + + hass.states.async_set( + "sensor.next_alarm", + trigger_dt.isoformat(), + {ATTR_DEVICE_CLASS: sensor.DEVICE_CLASS_TIMESTAMP}, + ) + + time_that_will_not_match_right_away = trigger_dt - timedelta(minutes=1) + + some_data = "{{ trigger.platform }}-{{ trigger.now.day }}-{{ trigger.now.hour }}-{{trigger.entity_id}}" + with patch( + "homeassistant.util.dt.utcnow", + return_value=dt_util.as_utc(time_that_will_not_match_right_away), + ): + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": {"platform": "time", "at": "sensor.next_alarm"}, + "action": { + "service": "test.automation", + "data_template": {"some": some_data}, + }, + } + }, + ) + await hass.async_block_till_done() + + async_fire_time_changed(hass, trigger_dt + timedelta(seconds=1)) + await hass.async_block_till_done() + + assert len(calls) == 1 + assert ( + calls[0].data["some"] + == f"time-{trigger_dt.day}-{trigger_dt.hour}-sensor.next_alarm" + ) + + trigger_dt += timedelta(days=1, hours=1) + + hass.states.async_set( + "sensor.next_alarm", + trigger_dt.isoformat(), + {ATTR_DEVICE_CLASS: sensor.DEVICE_CLASS_TIMESTAMP}, + ) + await hass.async_block_till_done() + + async_fire_time_changed(hass, trigger_dt + timedelta(seconds=1)) + await hass.async_block_till_done() + + assert len(calls) == 2 + assert ( + calls[1].data["some"] + == f"time-{trigger_dt.day}-{trigger_dt.hour}-sensor.next_alarm" + ) + + for broken in ("unknown", "unavailable", "invalid-ts"): + hass.states.async_set( + "sensor.next_alarm", + trigger_dt.isoformat(), + {ATTR_DEVICE_CLASS: sensor.DEVICE_CLASS_TIMESTAMP}, + ) + await hass.async_block_till_done() + hass.states.async_set( + "sensor.next_alarm", + broken, + {ATTR_DEVICE_CLASS: sensor.DEVICE_CLASS_TIMESTAMP}, + ) + await hass.async_block_till_done() + + async_fire_time_changed(hass, trigger_dt + timedelta(seconds=1)) + await hass.async_block_till_done() + + # We should not have listened to anything + assert len(calls) == 2 + + # Now without device class + hass.states.async_set( + "sensor.next_alarm", + trigger_dt.isoformat(), + {ATTR_DEVICE_CLASS: sensor.DEVICE_CLASS_TIMESTAMP}, + ) + await hass.async_block_till_done() + hass.states.async_set( + "sensor.next_alarm", + trigger_dt.isoformat(), + ) + await hass.async_block_till_done() + + async_fire_time_changed(hass, trigger_dt + timedelta(seconds=1)) + await hass.async_block_till_done() + + # We should not have listened to anything + assert len(calls) == 2 diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 693785f4ea7..c829d4413f0 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -179,15 +179,21 @@ def test_entity_domain(): """Test entity domain validation.""" schema = vol.Schema(cv.entity_domain("sensor")) - options = ("invalid_entity", "cover.demo") - - for value in options: + for value in ("invalid_entity", "cover.demo"): with pytest.raises(vol.MultipleInvalid): - print(value) schema(value) assert schema("sensor.LIGHT") == "sensor.light" + schema = vol.Schema(cv.entity_domain(("sensor", "binary_sensor"))) + + for value in ("invalid_entity", "cover.demo"): + with pytest.raises(vol.MultipleInvalid): + schema(value) + + assert schema("sensor.LIGHT") == "sensor.light" + assert schema("binary_sensor.LIGHT") == "binary_sensor.light" + def test_entities_domain(): """Test entities domain validation."""