Compare commits

...

7 Commits

Author SHA1 Message Date
Erik
17fdf8fe4c Adjust allowed types 2026-04-08 09:06:31 +02:00
Erik
bacce25235 Revert unnecessary inheritance of StringEntityTriggerBase 2026-04-08 09:04:32 +02:00
Erik
628739e34e Merge remote-tracking branch 'upstream/dev' into trigger_guard_against_non_hashable 2026-04-08 08:48:37 +02:00
Erik
b3c8fd7249 Remove _get_tracked_value from base class 2026-03-23 07:53:02 +01:00
Erik
4fc68b0adf Add test 2026-03-22 18:11:12 +01:00
Erik
5bbf0d2dec Fix cover triggers 2026-03-22 17:29:06 +01:00
Erik
7f453b56ad Guard against unexpected type in triggers 2026-03-22 16:18:19 +01:00
2 changed files with 44 additions and 10 deletions

View File

@@ -371,13 +371,6 @@ class EntityTriggerBase(Trigger):
"""Filter entities matching any of the domain specs."""
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
def _get_tracked_value(self, state: State) -> Any:
"""Get the tracked value from a state based on the DomainSpec."""
domain_spec = self._domain_specs[state.domain]
if domain_spec.value_source is None:
return state.state
return state.attributes.get(domain_spec.value_source)
@abc.abstractmethod
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
@@ -460,7 +453,21 @@ class EntityTriggerBase(Trigger):
)
class EntityTargetStateTriggerBase(EntityTriggerBase):
class StringEntityTriggerBase(EntityTriggerBase):
"""Trigger for string based entity state changes."""
def _get_tracked_value(self, state: State) -> bool | str | None:
"""Get the tracked value from a state based on the DomainSpec."""
domain_spec = self._domain_specs[state.domain]
if domain_spec.value_source is None:
return state.state
value = state.attributes.get(domain_spec.value_source)
if not isinstance(value, (bool, str)):
return None
return value
class EntityTargetStateTriggerBase(StringEntityTriggerBase):
"""Trigger for entity state changes to a specific state.
Uses _get_tracked_value to extract the value, so it works for both
@@ -485,7 +492,7 @@ class EntityTargetStateTriggerBase(EntityTriggerBase):
return self._get_tracked_value(state) in self._to_states
class EntityTransitionTriggerBase(EntityTriggerBase):
class EntityTransitionTriggerBase(StringEntityTriggerBase):
"""Trigger for entity state changes between specific states."""
_from_states: set[str | bool]
@@ -507,7 +514,7 @@ class EntityTransitionTriggerBase(EntityTriggerBase):
return self._get_tracked_value(state) in self._to_states
class EntityOriginStateTriggerBase(EntityTriggerBase):
class EntityOriginStateTriggerBase(StringEntityTriggerBase):
"""Trigger for entity state changes from a specific state."""
_from_state: str

View File

@@ -2960,6 +2960,33 @@ async def test_make_entity_target_state_trigger(
assert not trig.is_valid_state(wrong_value_state)
@pytest.mark.parametrize(
"attribute_value",
[
pytest.param(["a", "b"], id="list"),
pytest.param({"key": "value"}, id="dict"),
pytest.param(123, id="int"),
pytest.param(None, id="none"),
],
)
async def test_string_entity_trigger_base_non_string_attribute(
hass: HomeAssistant,
attribute_value: Any,
) -> None:
"""Test that attribute-based triggers handle non-string attribute values gracefully."""
trigger_cls = make_entity_target_state_trigger(
{"light": DomainSpec(value_source="effect")}, to_states={"rainbow"}
)
config = TriggerConfig(key="light.test", target={"entity_id": "light.bed"})
trig = trigger_cls(hass, config)
state_with_unhashable = State("light.bed", "on", {"effect": attribute_value})
# Non-string attribute values should not raise and should not match
assert not trig.is_valid_state(state_with_unhashable)
@pytest.mark.parametrize(
(
"domain_specs",