StateMachine is now case insensitive for entity ids

This commit is contained in:
Paulus Schoutsen 2014-12-26 23:26:39 -08:00
parent df3ddea23c
commit 89a548252a
2 changed files with 28 additions and 8 deletions

View File

@ -285,6 +285,19 @@ class TestStateMachine(unittest.TestCase):
self.assertEqual(1, len(specific_runs)) self.assertEqual(1, len(specific_runs))
self.assertEqual(3, len(wildcard_runs)) self.assertEqual(3, len(wildcard_runs))
def test_case_insensitivty(self):
runs = []
self.states.track_change(
'light.BoWl', lambda a, b, c: runs.append(1),
ha.MATCH_ALL, ha.MATCH_ALL)
self.states.set('light.BOWL', 'off')
self.bus._pool.block_till_done()
self.assertTrue(self.states.is_state('light.bowl', 'off'))
self.assertEqual(1, len(runs))
class TestServiceCall(unittest.TestCase): class TestServiceCall(unittest.TestCase):
""" Test ServiceCall class. """ """ Test ServiceCall class. """

View File

@ -15,6 +15,8 @@ import re
import datetime as dt import datetime as dt
import functools as ft import functools as ft
from requests.structures import CaseInsensitiveDict
from homeassistant.const import ( from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED, SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED,
@ -482,15 +484,18 @@ class StateMachine(object):
""" Helper class that tracks the state of different entities. """ """ Helper class that tracks the state of different entities. """
def __init__(self, bus): def __init__(self, bus):
self._states = {} self._states = CaseInsensitiveDict()
self._bus = bus self._bus = bus
self._lock = threading.Lock() self._lock = threading.Lock()
def entity_ids(self, domain_filter=None): def entity_ids(self, domain_filter=None):
""" List of entity ids that are being tracked. """ """ List of entity ids that are being tracked. """
if domain_filter is not None: if domain_filter is not None:
return [entity_id for entity_id in self._states.keys() domain_filter = domain_filter.lower()
if util.split_entity_id(entity_id)[0] == domain_filter]
return [state.entity_id for key, state
in self._states.lower_items()
if util.split_entity_id(key)[0] == domain_filter]
else: else:
return list(self._states.keys()) return list(self._states.keys())
@ -524,9 +529,9 @@ class StateMachine(object):
self._states[entity_id].state == state) self._states[entity_id].state == state)
def remove(self, entity_id): def remove(self, entity_id):
""" Removes a entity from the state machine. """ Removes an entity from the state machine.
Returns boolean to indicate if a entity was removed. """ Returns boolean to indicate if an entity was removed. """
with self._lock: with self._lock:
return self._states.pop(entity_id, None) is not None return self._states.pop(entity_id, None) is not None
@ -567,14 +572,16 @@ class StateMachine(object):
from_state = _process_match_param(from_state) from_state = _process_match_param(from_state)
to_state = _process_match_param(to_state) to_state = _process_match_param(to_state)
# Ensure it is a list with entity ids we want to match on # Ensure it is a lowercase list with entity ids we want to match on
if isinstance(entity_ids, str): if isinstance(entity_ids, str):
entity_ids = [entity_ids] entity_ids = [entity_ids.lower()]
else:
entity_ids = [entity_id.lower() for entity_id in entity_ids]
@ft.wraps(action) @ft.wraps(action)
def state_listener(event): def state_listener(event):
""" The listener that listens for specific state changes. """ """ The listener that listens for specific state changes. """
if event.data['entity_id'] in entity_ids and \ if event.data['entity_id'].lower() in entity_ids and \
'old_state' in event.data and \ 'old_state' in event.data and \
_matcher(event.data['old_state'].state, from_state) and \ _matcher(event.data['old_state'].state, from_state) and \
_matcher(event.data['new_state'].state, to_state): _matcher(event.data['new_state'].state, to_state):