Compare commits

...

4 Commits

Author SHA1 Message Date
Erik
b3c8fd7249 Remove _get_tracked_value from base class 2026-03-23 07:53:02 +01:00
Erik
4fc68b0adf Add test 2026-03-22 18:11:12 +01:00
Erik
5bbf0d2dec Fix cover triggers 2026-03-22 17:29:06 +01:00
Erik
7f453b56ad Guard against unexpected type in triggers 2026-03-22 16:18:19 +01:00
7 changed files with 60 additions and 24 deletions

View File

@@ -5,14 +5,14 @@ from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
EntityTriggerBase,
StringEntityTriggerBase,
Trigger,
)
from . import DOMAIN
class ButtonPressedTrigger(EntityTriggerBase):
class ButtonPressedTrigger(StringEntityTriggerBase):
"""Trigger for button entity presses."""
_domain_specs = {DOMAIN: DomainSpec()}

View File

@@ -2,13 +2,13 @@
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.trigger import EntityTriggerBase, Trigger
from homeassistant.helpers.trigger import StringEntityTriggerBase, Trigger
from .const import ATTR_IS_CLOSED, DOMAIN, CoverDeviceClass
from .models import CoverDomainSpec
class CoverTriggerBase(EntityTriggerBase[CoverDomainSpec]):
class CoverTriggerBase(StringEntityTriggerBase[CoverDomainSpec]):
"""Base trigger for cover state changes."""
def _get_value(self, state: State) -> str | bool | None:

View File

@@ -5,14 +5,14 @@ from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
EntityTriggerBase,
StringEntityTriggerBase,
Trigger,
)
from . import DOMAIN
class SceneActivatedTrigger(EntityTriggerBase):
class SceneActivatedTrigger(StringEntityTriggerBase):
"""Trigger for scene entity activations."""
_domain_specs = {DOMAIN: DomainSpec()}

View File

@@ -6,14 +6,14 @@ from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
EntityTriggerBase,
StringEntityTriggerBase,
Trigger,
)
from .const import DOMAIN
class SelectionChangedTrigger(EntityTriggerBase):
class SelectionChangedTrigger(StringEntityTriggerBase):
"""Trigger for select entity when its selection changes."""
_domain_specs = {DOMAIN: DomainSpec(), INPUT_SELECT_DOMAIN: DomainSpec()}

View File

@@ -5,14 +5,14 @@ from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
EntityTriggerBase,
StringEntityTriggerBase,
Trigger,
)
from .const import DOMAIN
class TextChangedTrigger(EntityTriggerBase):
class TextChangedTrigger(StringEntityTriggerBase):
"""Trigger for text entity when its content changes."""
_domain_specs = {DOMAIN: DomainSpec()}

View File

@@ -363,13 +363,6 @@ class EntityTriggerBase[DomainSpecT: DomainSpec = DomainSpec](Trigger):
"""Filter entities matching any of the domain specs."""
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
def _get_tracked_value(self, state: State) -> Any:
"""Get the tracked value from a state based on the DomainSpec."""
domain_spec = self._domain_specs[state.domain]
if domain_spec.value_source is None:
return state.state
return state.attributes.get(domain_spec.value_source)
@abc.abstractmethod
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
@@ -452,7 +445,23 @@ class EntityTriggerBase[DomainSpecT: DomainSpec = DomainSpec](Trigger):
)
class EntityTargetStateTriggerBase(EntityTriggerBase):
class StringEntityTriggerBase[DomainSpecT: DomainSpec = DomainSpec](
EntityTriggerBase[DomainSpecT]
):
"""Trigger for string based entity state changes."""
def _get_tracked_value(self, state: State) -> str | None:
"""Get the tracked value from a state based on the DomainSpec."""
domain_spec = self._domain_specs[state.domain]
if domain_spec.value_source is None:
return state.state
value = state.attributes.get(domain_spec.value_source)
if not isinstance(value, str):
return None
return value
class EntityTargetStateTriggerBase(StringEntityTriggerBase):
"""Trigger for entity state changes to a specific state.
Uses _get_tracked_value to extract the value, so it works for both
@@ -477,7 +486,7 @@ class EntityTargetStateTriggerBase(EntityTriggerBase):
return self._get_tracked_value(state) in self._to_states
class EntityTransitionTriggerBase(EntityTriggerBase):
class EntityTransitionTriggerBase(StringEntityTriggerBase):
"""Trigger for entity state changes between specific states."""
_from_states: set[str]
@@ -499,7 +508,7 @@ class EntityTransitionTriggerBase(EntityTriggerBase):
return self._get_tracked_value(state) in self._to_states
class EntityOriginStateTriggerBase(EntityTriggerBase):
class EntityOriginStateTriggerBase(StringEntityTriggerBase):
"""Trigger for entity state changes from a specific state."""
_from_state: str

View File

@@ -52,8 +52,8 @@ from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS,
EntityNumericalStateChangedTriggerWithUnitBase,
EntityNumericalStateCrossedThresholdTriggerWithUnitBase,
EntityTriggerBase,
PluggableAction,
StringEntityTriggerBase,
Trigger,
TriggerActionRunner,
TriggerConfig,
@@ -2395,10 +2395,10 @@ async def test_numerical_state_attribute_crossed_threshold_with_unit_error_handl
def _make_trigger(
hass: HomeAssistant, domain_specs: Mapping[str, DomainSpec]
) -> EntityTriggerBase:
"""Create a minimal EntityTriggerBase subclass with the given domain specs."""
) -> StringEntityTriggerBase:
"""Create a minimal StringEntityTriggerBase subclass with the given domain specs."""
class _SimpleTrigger(EntityTriggerBase):
class _SimpleTrigger(StringEntityTriggerBase):
"""Minimal concrete trigger for testing entity_filter."""
_domain_specs = domain_specs
@@ -2580,6 +2580,33 @@ async def test_make_entity_target_state_trigger(
assert not trig.is_valid_state(wrong_value_state)
@pytest.mark.parametrize(
"attribute_value",
[
pytest.param(["a", "b"], id="list"),
pytest.param({"key": "value"}, id="dict"),
pytest.param(123, id="int"),
pytest.param(None, id="none"),
],
)
async def test_string_entity_trigger_base_non_string_attribute(
hass: HomeAssistant,
attribute_value: Any,
) -> None:
"""Test that attribute-based triggers handle non-string attribute values gracefully."""
trigger_cls = make_entity_target_state_trigger(
{"light": DomainSpec(value_source="effect")}, to_states={"rainbow"}
)
config = TriggerConfig(key="light.test", target={"entity_id": "light.bed"})
trig = trigger_cls(hass, config)
state_with_unhashable = State("light.bed", "on", {"effect": attribute_value})
# Non-string attribute values should not raise and should not match
assert not trig.is_valid_state(state_with_unhashable)
@pytest.mark.parametrize(
(
"domain_specs",