diff --git a/homeassistant/components/light/trigger.py b/homeassistant/components/light/trigger.py index df0e58417a2..321efb00755 100644 --- a/homeassistant/components/light/trigger.py +++ b/homeassistant/components/light/trigger.py @@ -1,32 +1,42 @@ """Provides triggers for lights.""" -from typing import cast, override +from typing import Final, cast, override import voluptuous as vol -from homeassistant.const import ATTR_ENTITY_ID, CONF_PLATFORM, CONF_STATE, MATCH_ALL -from homeassistant.core import ( - CALLBACK_TYPE, - Event, - EventStateChangedData, - HassJob, - HomeAssistant, - callback, +from homeassistant.const import ( + ATTR_ENTITY_ID, + CONF_PLATFORM, + CONF_STATE, + STATE_OFF, + STATE_ON, ) +from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.helpers.event import process_state_match -from homeassistant.helpers.target import async_track_target_selector_state_change_event +from homeassistant.helpers.target import ( + TargetStateChangedData, + async_track_target_selector_state_change_event, +) from homeassistant.helpers.trigger import Trigger, TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType from .const import DOMAIN +ATTR_BEHAVIOR: Final = "behavior" +BEHAVIOR_FIRST = "first" +BEHAVIOR_LAST = "last" +BEHAVIOR_ANY = "any" + STATE_PLATFORM_TYPE = f"{DOMAIN}.state" STATE_TRIGGER_SCHEMA = vol.All( cv.TRIGGER_BASE_SCHEMA.extend( { vol.Required(CONF_PLATFORM): STATE_PLATFORM_TYPE, - vol.Optional(CONF_STATE, default=MATCH_ALL): vol.Any(str, [str], None), + vol.Required(CONF_STATE): vol.In([STATE_ON, STATE_OFF]), + vol.Required(ATTR_BEHAVIOR, default=BEHAVIOR_ANY): vol.In( + [BEHAVIOR_FIRST, BEHAVIOR_LAST, BEHAVIOR_ANY] + ), **cv.ENTITY_SERVICE_FIELDS, }, ), @@ -60,19 +70,53 @@ class StateTrigger(Trigger): job = HassJob(action, f"light state trigger {trigger_info}") trigger_data = trigger_info["trigger_data"] - match_state = process_state_match(self.config.get(CONF_STATE)) + behavior = self.config.get(ATTR_BEHAVIOR) + match_config_state = process_state_match(self.config.get(CONF_STATE)) + + def check_all_match(entity_ids: set[str]) -> bool: + """Check if all entity states match.""" + return all( + match_config_state(state.state) + for entity_id in entity_ids + if (state := self.hass.states.get(entity_id)) is not None + ) + + def check_one_match(entity_ids: set[str]) -> bool: + """Check that only one entity state matches.""" + return ( + sum( + match_config_state(state.state) + for entity_id in entity_ids + if (state := self.hass.states.get(entity_id)) is not None + ) + == 1 + ) @callback - def state_change_listener(event: Event[EventStateChangedData]) -> None: + def state_change_listener( + target_state_change_data: TargetStateChangedData, + ) -> None: """Listen for state changes and call action.""" + event = target_state_change_data.state_change_event entity_id = event.data["entity_id"] from_state = event.data["old_state"] to_state = event.data["new_state"] if to_state is None: return - if not match_state(to_state.state): + + # This check is required for "first" behavior, to check that it went from zero + # entities matching the state to one. Otherwise, if previously there were two + # entities on CONF_STATE and one changed, this would trigger. + # For "last" behavior it is not required, but serves as a quicker fail check. + if not match_config_state(to_state.state): return + if behavior == BEHAVIOR_LAST: + if not check_all_match(target_state_change_data.targeted_entity_ids): + return + elif behavior == BEHAVIOR_FIRST: + if not check_one_match(target_state_change_data.targeted_entity_ids): + return self.hass.async_run_hass_job( job, diff --git a/homeassistant/components/light/triggers.yaml b/homeassistant/components/light/triggers.yaml index e60e7eb2153..dd1c72e84b2 100644 --- a/homeassistant/components/light/triggers.yaml +++ b/homeassistant/components/light/triggers.yaml @@ -7,3 +7,12 @@ state: required: true selector: state: + behavior: + required: true + default: any + selector: + select: + options: + - first + - last + - any diff --git a/tests/components/light/test_trigger.py b/tests/components/light/test_trigger.py index 2579a7334b8..deb7c94765e 100644 --- a/tests/components/light/test_trigger.py +++ b/tests/components/light/test_trigger.py @@ -80,6 +80,14 @@ async def target_lights(hass: HomeAssistant) -> None: ) entity_reg.async_update_entity(light_label.entity_id, labels={label.label_id}) + # Return all available light entities + return [ + "light.standalone_light", + "light.label_light", + "light.area_light", + "light.device_light", + ] + @pytest.mark.usefixtures("target_lights") @pytest.mark.parametrize( @@ -103,7 +111,7 @@ async def test_light_state_trigger_behavior_any( entity_id: str, state: str, ) -> None: - """Test that the light state trigger fires when light state changes to a specific state.""" + """Test that the light state trigger fires when any light state changes to a specific state.""" await async_setup_component(hass, "light", {}) reverse_state = STATE_OFF if state == STATE_ON else STATE_ON @@ -113,42 +121,155 @@ async def test_light_state_trigger_behavior_any( hass, automation.DOMAIN, { - automation.DOMAIN: [ - { - "alias": "Trigger when state changes to specific state", - "trigger": { - CONF_PLATFORM: "light.state", - CONF_STATE: state, - **trigger_target_config, - }, - "action": { - "service": "test.automation", - "data_template": {CONF_ENTITY_ID: f"{entity_id}"}, - }, + automation.DOMAIN: { + "trigger": { + CONF_PLATFORM: "light.state", + CONF_STATE: state, + **trigger_target_config, }, - { - "alias": "Trigger when state changes to any state", - "trigger": { - CONF_PLATFORM: "light.state", - **trigger_target_config, - }, - "action": { - "service": "test.automation", - "data_template": {CONF_ENTITY_ID: f"{entity_id}"}, - }, + "action": { + "service": "test.automation", + "data_template": {CONF_ENTITY_ID: f"{entity_id}"}, }, - ] + } }, ) hass.states.async_set(entity_id, state) await hass.async_block_till_done() - assert len(service_calls) == 2 + assert len(service_calls) == 1 assert service_calls[0].data[CONF_ENTITY_ID] == entity_id - assert service_calls[1].data[CONF_ENTITY_ID] == entity_id service_calls.clear() hass.states.async_set(entity_id, reverse_state) await hass.async_block_till_done() + assert len(service_calls) == 0 + + +@pytest.mark.parametrize( + ("trigger_target_config", "entity_id"), + [ + ({CONF_ENTITY_ID: "light.standalone_light"}, "light.standalone_light"), + ({ATTR_LABEL_ID: "test_label"}, "light.label_light"), + ({ATTR_AREA_ID: "test_area"}, "light.area_light"), + ({ATTR_FLOOR_ID: "test_floor"}, "light.area_light"), + ({ATTR_LABEL_ID: "test_label"}, "light.device_light"), + ({ATTR_AREA_ID: "test_area"}, "light.device_light"), + ({ATTR_FLOOR_ID: "test_floor"}, "light.device_light"), + ({ATTR_DEVICE_ID: "test_device"}, "light.device_light"), + ], +) +@pytest.mark.parametrize("state", [STATE_ON, STATE_OFF]) +async def test_light_state_trigger_behavior_first( + hass: HomeAssistant, + service_calls: list[ServiceCall], + target_lights: list[str], + trigger_target_config: dict, + entity_id: str, + state: str, +) -> None: + """Test that the light state trigger fires when the first light changes to a specific state.""" + await async_setup_component(hass, "light", {}) + + reverse_state = STATE_OFF if state == STATE_ON else STATE_ON + for other_entity_id in target_lights: + hass.states.async_set(other_entity_id, reverse_state) + await hass.async_block_till_done() + + await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": { + CONF_PLATFORM: "light.state", + CONF_STATE: state, + "behavior": "first", + **trigger_target_config, + }, + "action": { + "service": "test.automation", + "data_template": {CONF_ENTITY_ID: f"{entity_id}"}, + }, + } + }, + ) + hass.states.async_set(entity_id, state) + await hass.async_block_till_done() assert len(service_calls) == 1 assert service_calls[0].data[CONF_ENTITY_ID] == entity_id + service_calls.clear() + + # Triggering other lights should not cause any service calls after the first one + for other_entity_id in target_lights: + hass.states.async_set(other_entity_id, state) + await hass.async_block_till_done() + for other_entity_id in target_lights: + hass.states.async_set(other_entity_id, reverse_state) + await hass.async_block_till_done() + assert len(service_calls) == 0 + + hass.states.async_set(entity_id, state) + await hass.async_block_till_done() + assert len(service_calls) == 1 + assert service_calls[0].data[CONF_ENTITY_ID] == entity_id + + +@pytest.mark.parametrize( + ("trigger_target_config", "entity_id"), + [ + ({CONF_ENTITY_ID: "light.standalone_light"}, "light.standalone_light"), + ({ATTR_LABEL_ID: "test_label"}, "light.label_light"), + ({ATTR_AREA_ID: "test_area"}, "light.area_light"), + ({ATTR_FLOOR_ID: "test_floor"}, "light.area_light"), + ({ATTR_LABEL_ID: "test_label"}, "light.device_light"), + ({ATTR_AREA_ID: "test_area"}, "light.device_light"), + ({ATTR_FLOOR_ID: "test_floor"}, "light.device_light"), + ({ATTR_DEVICE_ID: "test_device"}, "light.device_light"), + ], +) +@pytest.mark.parametrize("state", [STATE_ON, STATE_OFF]) +async def test_light_state_trigger_behavior_last( + hass: HomeAssistant, + service_calls: list[ServiceCall], + target_lights: list[str], + trigger_target_config: dict, + entity_id: str, + state: str, +) -> None: + """Test that the light state trigger fires when the last light changes to a specific state.""" + await async_setup_component(hass, "light", {}) + + reverse_state = STATE_OFF if state == STATE_ON else STATE_ON + for other_entity_id in target_lights: + hass.states.async_set(other_entity_id, reverse_state) + await hass.async_block_till_done() + + await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": { + CONF_PLATFORM: "light.state", + CONF_STATE: state, + "behavior": "last", + **trigger_target_config, + }, + "action": { + "service": "test.automation", + "data_template": {CONF_ENTITY_ID: f"{entity_id}"}, + }, + } + }, + ) + + target_lights.remove(entity_id) + for other_entity_id in target_lights: + hass.states.async_set(other_entity_id, state) + await hass.async_block_till_done() + assert len(service_calls) == 0 + + hass.states.async_set(entity_id, state) + await hass.async_block_till_done() + assert len(service_calls) == 1