From 30fdfc454f9cd6a1a0abc356a46105042a462cdd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 9 May 2022 05:48:38 -0500 Subject: [PATCH] Avoid lowercasing entities after template ratelimit recovery (#71415) --- homeassistant/helpers/event.py | 28 ++++++++++--- tests/helpers/test_event.py | 77 ++++++++++++++++++++++++++++++---- 2 files changed, 93 insertions(+), 12 deletions(-) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 686fde89fbb..c1229dc3e7c 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -68,8 +68,8 @@ class TrackStates: """Class for keeping track of states being tracked. all_states: All states on the system are being tracked - entities: Entities to track - domains: Domains to track + entities: Lowercased entities to track + domains: Lowercased domains to track """ all_states: bool @@ -248,7 +248,16 @@ def async_track_state_change_event( """ if not (entity_ids := _async_string_to_lower_list(entity_ids)): return _remove_empty_listener + return _async_track_state_change_event(hass, entity_ids, action) + +@bind_hass +def _async_track_state_change_event( + hass: HomeAssistant, + entity_ids: str | Iterable[str], + action: Callable[[Event], Any], +) -> CALLBACK_TYPE: + """async_track_state_change_event without lowercasing.""" entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {}) if TRACK_STATE_CHANGE_LISTENER not in hass.data: @@ -419,7 +428,16 @@ def async_track_state_added_domain( """Track state change events when an entity is added to domains.""" if not (domains := _async_string_to_lower_list(domains)): return _remove_empty_listener + return _async_track_state_added_domain(hass, domains, action) + +@bind_hass +def _async_track_state_added_domain( + hass: HomeAssistant, + domains: str | Iterable[str], + action: Callable[[Event], Any], +) -> CALLBACK_TYPE: + """async_track_state_added_domain without lowercasing.""" domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {}) if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data: @@ -626,7 +644,7 @@ class _TrackStateChangeFiltered: if not entities: return - self._listeners[_ENTITIES_LISTENER] = async_track_state_change_event( + self._listeners[_ENTITIES_LISTENER] = _async_track_state_change_event( self.hass, entities, self._action ) @@ -643,7 +661,7 @@ class _TrackStateChangeFiltered: if not domains: return - self._listeners[_DOMAINS_LISTENER] = async_track_state_added_domain( + self._listeners[_DOMAINS_LISTENER] = _async_track_state_added_domain( self.hass, domains, self._state_added ) @@ -1217,7 +1235,7 @@ def async_track_same_state( else: async_remove_state_for_cancel = async_track_state_change_event( hass, - [entity_ids] if isinstance(entity_ids, str) else entity_ids, + entity_ids, state_for_cancel_listener, ) diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 7644abaa558..9b0a3e1abd5 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -1983,6 +1983,69 @@ async def test_track_template_result_and_conditional(hass): assert specific_runs[2] == "on" +async def test_track_template_result_and_conditional_upper_case(hass): + """Test tracking template with an and conditional with an upper case template.""" + specific_runs = [] + hass.states.async_set("light.a", "off") + hass.states.async_set("light.b", "off") + template_str = '{% if states.light.A.state == "on" and states.light.B.state == "on" %}on{% else %}off{% endif %}' + + template = Template(template_str, hass) + + def specific_run_callback(event, updates): + specific_runs.append(updates.pop().result) + + info = async_track_template_result( + hass, [TrackTemplate(template, None)], specific_run_callback + ) + await hass.async_block_till_done() + assert info.listeners == { + "all": False, + "domains": set(), + "entities": {"light.a"}, + "time": False, + } + + hass.states.async_set("light.b", "on") + await hass.async_block_till_done() + assert len(specific_runs) == 0 + + hass.states.async_set("light.a", "on") + await hass.async_block_till_done() + assert len(specific_runs) == 1 + assert specific_runs[0] == "on" + assert info.listeners == { + "all": False, + "domains": set(), + "entities": {"light.a", "light.b"}, + "time": False, + } + + hass.states.async_set("light.b", "off") + await hass.async_block_till_done() + assert len(specific_runs) == 2 + assert specific_runs[1] == "off" + assert info.listeners == { + "all": False, + "domains": set(), + "entities": {"light.a", "light.b"}, + "time": False, + } + + hass.states.async_set("light.a", "off") + await hass.async_block_till_done() + assert len(specific_runs) == 2 + + hass.states.async_set("light.b", "on") + await hass.async_block_till_done() + assert len(specific_runs) == 2 + + hass.states.async_set("light.a", "on") + await hass.async_block_till_done() + assert len(specific_runs) == 3 + assert specific_runs[2] == "on" + + async def test_track_template_result_iterator(hass): """Test tracking template.""" iterator_runs = [] @@ -2187,7 +2250,7 @@ async def test_track_template_rate_limit(hass): assert refresh_runs == [0] info.async_refresh() assert refresh_runs == [0, 1] - hass.states.async_set("sensor.two", "any") + hass.states.async_set("sensor.TWO", "any") await hass.async_block_till_done() assert refresh_runs == [0, 1] next_time = dt_util.utcnow() + timedelta(seconds=0.125) @@ -2200,7 +2263,7 @@ async def test_track_template_rate_limit(hass): hass.states.async_set("sensor.three", "any") await hass.async_block_till_done() assert refresh_runs == [0, 1, 2] - hass.states.async_set("sensor.four", "any") + hass.states.async_set("sensor.fOuR", "any") await hass.async_block_till_done() assert refresh_runs == [0, 1, 2] next_time = dt_util.utcnow() + timedelta(seconds=0.125 * 2) @@ -2385,7 +2448,7 @@ async def test_track_template_rate_limit_super_3(hass): await hass.async_block_till_done() assert refresh_runs == [] - hass.states.async_set("sensor.one", "any") + hass.states.async_set("sensor.ONE", "any") await hass.async_block_till_done() assert refresh_runs == [] info.async_refresh() @@ -2408,7 +2471,7 @@ async def test_track_template_rate_limit_super_3(hass): hass.states.async_set("sensor.four", "any") await hass.async_block_till_done() assert refresh_runs == [1, 2] - hass.states.async_set("sensor.five", "any") + hass.states.async_set("sensor.FIVE", "any") await hass.async_block_till_done() assert refresh_runs == [1, 2] next_time = dt_util.utcnow() + timedelta(seconds=0.125 * 2) @@ -2453,7 +2516,7 @@ async def test_track_template_rate_limit_suppress_listener(hass): await hass.async_block_till_done() assert refresh_runs == [0] - hass.states.async_set("sensor.one", "any") + hass.states.async_set("sensor.oNe", "any") await hass.async_block_till_done() assert refresh_runs == [0] info.async_refresh() @@ -2482,7 +2545,7 @@ async def test_track_template_rate_limit_suppress_listener(hass): "time": False, } assert refresh_runs == [0, 1, 2] - hass.states.async_set("sensor.three", "any") + hass.states.async_set("sensor.Three", "any") await hass.async_block_till_done() assert refresh_runs == [0, 1, 2] hass.states.async_set("sensor.four", "any") @@ -2509,7 +2572,7 @@ async def test_track_template_rate_limit_suppress_listener(hass): "time": False, } assert refresh_runs == [0, 1, 2, 4] - hass.states.async_set("sensor.five", "any") + hass.states.async_set("sensor.Five", "any") await hass.async_block_till_done() # Rate limit hit and the all listener is shut off assert info.listeners == {