mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 11:47:06 +00:00
StateMachine is now case insensitive for entity ids
This commit is contained in:
parent
df3ddea23c
commit
89a548252a
@ -285,6 +285,19 @@ class TestStateMachine(unittest.TestCase):
|
|||||||
self.assertEqual(1, len(specific_runs))
|
self.assertEqual(1, len(specific_runs))
|
||||||
self.assertEqual(3, len(wildcard_runs))
|
self.assertEqual(3, len(wildcard_runs))
|
||||||
|
|
||||||
|
def test_case_insensitivty(self):
|
||||||
|
runs = []
|
||||||
|
|
||||||
|
self.states.track_change(
|
||||||
|
'light.BoWl', lambda a, b, c: runs.append(1),
|
||||||
|
ha.MATCH_ALL, ha.MATCH_ALL)
|
||||||
|
|
||||||
|
self.states.set('light.BOWL', 'off')
|
||||||
|
self.bus._pool.block_till_done()
|
||||||
|
|
||||||
|
self.assertTrue(self.states.is_state('light.bowl', 'off'))
|
||||||
|
self.assertEqual(1, len(runs))
|
||||||
|
|
||||||
|
|
||||||
class TestServiceCall(unittest.TestCase):
|
class TestServiceCall(unittest.TestCase):
|
||||||
""" Test ServiceCall class. """
|
""" Test ServiceCall class. """
|
||||||
|
@ -15,6 +15,8 @@ import re
|
|||||||
import datetime as dt
|
import datetime as dt
|
||||||
import functools as ft
|
import functools as ft
|
||||||
|
|
||||||
|
from requests.structures import CaseInsensitiveDict
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||||
SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED,
|
SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED,
|
||||||
@ -482,15 +484,18 @@ class StateMachine(object):
|
|||||||
""" Helper class that tracks the state of different entities. """
|
""" Helper class that tracks the state of different entities. """
|
||||||
|
|
||||||
def __init__(self, bus):
|
def __init__(self, bus):
|
||||||
self._states = {}
|
self._states = CaseInsensitiveDict()
|
||||||
self._bus = bus
|
self._bus = bus
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def entity_ids(self, domain_filter=None):
|
def entity_ids(self, domain_filter=None):
|
||||||
""" List of entity ids that are being tracked. """
|
""" List of entity ids that are being tracked. """
|
||||||
if domain_filter is not None:
|
if domain_filter is not None:
|
||||||
return [entity_id for entity_id in self._states.keys()
|
domain_filter = domain_filter.lower()
|
||||||
if util.split_entity_id(entity_id)[0] == domain_filter]
|
|
||||||
|
return [state.entity_id for key, state
|
||||||
|
in self._states.lower_items()
|
||||||
|
if util.split_entity_id(key)[0] == domain_filter]
|
||||||
else:
|
else:
|
||||||
return list(self._states.keys())
|
return list(self._states.keys())
|
||||||
|
|
||||||
@ -524,9 +529,9 @@ class StateMachine(object):
|
|||||||
self._states[entity_id].state == state)
|
self._states[entity_id].state == state)
|
||||||
|
|
||||||
def remove(self, entity_id):
|
def remove(self, entity_id):
|
||||||
""" Removes a entity from the state machine.
|
""" Removes an entity from the state machine.
|
||||||
|
|
||||||
Returns boolean to indicate if a entity was removed. """
|
Returns boolean to indicate if an entity was removed. """
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._states.pop(entity_id, None) is not None
|
return self._states.pop(entity_id, None) is not None
|
||||||
|
|
||||||
@ -567,14 +572,16 @@ class StateMachine(object):
|
|||||||
from_state = _process_match_param(from_state)
|
from_state = _process_match_param(from_state)
|
||||||
to_state = _process_match_param(to_state)
|
to_state = _process_match_param(to_state)
|
||||||
|
|
||||||
# Ensure it is a list with entity ids we want to match on
|
# Ensure it is a lowercase list with entity ids we want to match on
|
||||||
if isinstance(entity_ids, str):
|
if isinstance(entity_ids, str):
|
||||||
entity_ids = [entity_ids]
|
entity_ids = [entity_ids.lower()]
|
||||||
|
else:
|
||||||
|
entity_ids = [entity_id.lower() for entity_id in entity_ids]
|
||||||
|
|
||||||
@ft.wraps(action)
|
@ft.wraps(action)
|
||||||
def state_listener(event):
|
def state_listener(event):
|
||||||
""" The listener that listens for specific state changes. """
|
""" The listener that listens for specific state changes. """
|
||||||
if event.data['entity_id'] in entity_ids and \
|
if event.data['entity_id'].lower() in entity_ids and \
|
||||||
'old_state' in event.data and \
|
'old_state' in event.data and \
|
||||||
_matcher(event.data['old_state'].state, from_state) and \
|
_matcher(event.data['old_state'].state, from_state) and \
|
||||||
_matcher(event.data['new_state'].state, to_state):
|
_matcher(event.data['new_state'].state, to_state):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user